from functools import partial

import numpy as onp

from autograd.extend import SparseObject, VJPNode, defvjp, defvjp_argnum, primitive, register_notrace, vspace

from ..util import func
from . import numpy_wrapper as anp
from .numpy_boxes import ArrayBox

# ----- Non-differentiable functions -----
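# These functions either return discrete values (indices, booleans, shapes, dtypes) or
# are piecewise constant, so their derivatives with respect to continuous inputs are
# zero wherever they are defined; register_notrace lets them run on boxed values
# without being recorded on the tape.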

nograd_functions = [
    anp.floor,
    anp.ceil,
    anp.round,
    anp.rint,
    anp.around,
    anp.fix,
    anp.trunc,
    anp.all,
    anp.any,
    anp.argmax,
    anp.argmin,
    anp.argpartition,
    anp.argsort,
    anp.argwhere,
    anp.nonzero,
    anp.flatnonzero,
    anp.count_nonzero,
    anp.searchsorted,
    anp.sign,
    anp.ndim,
    anp.shape,
    anp.floor_divide,
    anp.logical_and,
    anp.logical_or,
    anp.logical_not,
    anp.logical_xor,
    anp.isfinite,
    anp.isinf,
    anp.isnan,
    anp.isneginf,
    anp.isposinf,
    anp.allclose,
    anp.isclose,
    anp.array_equal,
    anp.array_equiv,
    anp.greater,
    anp.greater_equal,
    anp.less,
    anp.less_equal,
    anp.equal,
    anp.not_equal,
    anp.iscomplexobj,
    anp.iscomplex,
    anp.size,
    anp.isscalar,
    anp.isreal,
    anp.zeros_like,
    anp.ones_like,
    anp.result_type,
]

for fun in nograd_functions:
    register_notrace(VJPNode, fun)

# ----- Functions that are constant w.r.t. continuous inputs -----
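# nan_to_num is the identity on finite entries and clamps non-finite ones to constants,
# so gradient flows only through the entries where x was already finite.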

defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.0))

# ----- Binary ufuncs -----
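# Binary ufuncs broadcast their arguments, so every VJP below is wrapped in
# unbroadcast_f (defined near the end of this module), which sums the incoming
# gradient over the broadcast dimensions to restore the operand's original shape.
# For example, if x is a scalar in anp.add(x, y) with y of shape (3,), the VJP
# w.r.t. x reduces to lambda g: anp.sum(g).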

defvjp(
    anp.add, lambda ans, x, y: unbroadcast_f(x, lambda g: g), lambda ans, x, y: unbroadcast_f(y, lambda g: g)
)
defvjp(
    anp.multiply,
    lambda ans, x, y: unbroadcast_f(x, lambda g: y * g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: x * g),
)
defvjp(
    anp.subtract,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g),
)
defvjp(
    anp.divide,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * x / y**2),
)
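# For the max/min family below, balanced_eq (defined near the end of this module)
# routes the gradient to whichever argument produced the output, splitting it evenly
# between the two arguments when they are exactly tied.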
defvjp(
    anp.maximum,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.minimum,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.fmax,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.fmin,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * balanced_eq(x, ans, y)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * balanced_eq(y, ans, x)),
)
defvjp(
    anp.logaddexp,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * anp.exp(x - ans)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * anp.exp(y - ans)),
)
defvjp(
    anp.logaddexp2,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * 2 ** (x - ans)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * 2 ** (y - ans)),
)
defvjp(
    anp.true_divide,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * x / y**2),
)
defvjp(
    anp.mod,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * anp.floor(x / y)),
)
defvjp(
    anp.remainder,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g),
    lambda ans, x, y: unbroadcast_f(y, lambda g: -g * anp.floor(x / y)),
)
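# The where/replace_zero guards in the power VJP avoid evaluating x ** (y - 1) at
# y == 0 and log(x) at x == 0, where those subexpressions would produce inf or nan.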
defvjp(
    anp.power,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * y * x ** anp.where(y, y - 1, 1.0)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * anp.log(replace_zero(x, 1.0)) * ans),
)
defvjp(
    anp.arctan2,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * y / (x**2 + y**2)),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * -x / (x**2 + y**2)),
)
defvjp(
    anp.hypot,
    lambda ans, x, y: unbroadcast_f(x, lambda g: g * x / ans),
    lambda ans, x, y: unbroadcast_f(y, lambda g: g * y / ans),
)

# ----- Simple grads -----

defvjp(anp.negative, lambda ans, x: lambda g: -g)
defvjp(anp.abs, lambda ans, x: lambda g: g * replace_zero(anp.conj(x), 0.0) / replace_zero(ans, 1.0))
defvjp(anp.fabs, lambda ans, x: lambda g: anp.sign(x) * g)  # fabs doesn't take complex numbers.
defvjp(anp.absolute, lambda ans, x: lambda g: g * anp.conj(x) / ans)
defvjp(anp.reciprocal, lambda ans, x: lambda g: -g / x**2)
defvjp(anp.exp, lambda ans, x: lambda g: ans * g)
defvjp(anp.exp2, lambda ans, x: lambda g: ans * anp.log(2) * g)
defvjp(anp.expm1, lambda ans, x: lambda g: (ans + 1) * g)
defvjp(anp.log, lambda ans, x: lambda g: g / x)
defvjp(anp.log2, lambda ans, x: lambda g: g / x / anp.log(2))
defvjp(anp.log10, lambda ans, x: lambda g: g / x / anp.log(10))
defvjp(anp.log1p, lambda ans, x: lambda g: g / (x + 1))
defvjp(anp.sin, lambda ans, x: lambda g: g * anp.cos(x))
defvjp(anp.cos, lambda ans, x: lambda g: -g * anp.sin(x))
defvjp(anp.tan, lambda ans, x: lambda g: g / anp.cos(x) ** 2)
defvjp(anp.arcsin, lambda ans, x: lambda g: g / anp.sqrt(1 - x**2))
defvjp(anp.arccos, lambda ans, x: lambda g: -g / anp.sqrt(1 - x**2))
defvjp(anp.arctan, lambda ans, x: lambda g: g / (1 + x**2))
defvjp(anp.sinh, lambda ans, x: lambda g: g * anp.cosh(x))
defvjp(anp.cosh, lambda ans, x: lambda g: g * anp.sinh(x))
defvjp(anp.tanh, lambda ans, x: lambda g: g / anp.cosh(x) ** 2)
defvjp(anp.arcsinh, lambda ans, x: lambda g: g / anp.sqrt(x**2 + 1))
defvjp(anp.arccosh, lambda ans, x: lambda g: g / anp.sqrt(x**2 - 1))
defvjp(anp.arctanh, lambda ans, x: lambda g: g / (1 - x**2))
defvjp(anp.rad2deg, lambda ans, x: lambda g: g / anp.pi * 180.0)
defvjp(anp.degrees, lambda ans, x: lambda g: g / anp.pi * 180.0)
defvjp(anp.deg2rad, lambda ans, x: lambda g: g * anp.pi / 180.0)
defvjp(anp.radians, lambda ans, x: lambda g: g * anp.pi / 180.0)
defvjp(anp.square, lambda ans, x: lambda g: g * 2 * x)
defvjp(anp.sqrt, lambda ans, x: lambda g: g * 0.5 * x**-0.5)
defvjp(
    anp.sinc,
    lambda ans, x: lambda g: g * (anp.cos(anp.pi * x) * anp.pi * x - anp.sin(anp.pi * x)) / (anp.pi * x**2),
)
defvjp(anp.reshape, lambda ans, x, shape, order=None: lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.roll, lambda ans, x, shift, axis=None: lambda g: anp.roll(g, -shift, axis=axis))
defvjp(anp.array_split, lambda ans, ary, idxs, axis=0: lambda g: anp.concatenate(g, axis=axis))
defvjp(anp.split, lambda ans, ary, idxs, axis=0: lambda g: anp.concatenate(g, axis=axis))
defvjp(anp.vsplit, lambda ans, ary, idxs: lambda g: anp.concatenate(g, axis=0))
defvjp(anp.hsplit, lambda ans, ary, idxs: lambda g: anp.concatenate(g, axis=1))
defvjp(anp.dsplit, lambda ans, ary, idxs: lambda g: anp.concatenate(g, axis=2))
defvjp(anp.ravel, lambda ans, x, order=None: lambda g: anp.reshape(g, anp.shape(x), order=order))
defvjp(anp.expand_dims, lambda ans, x, axis: lambda g: anp.reshape(g, anp.shape(x)))
defvjp(anp.squeeze, lambda ans, x, axis=None: lambda g: anp.reshape(g, anp.shape(x)))
defvjp(anp.diag, lambda ans, x, k=0: lambda g: anp.diag(g, k))
defvjp(anp.flipud, lambda ans, x: lambda g: anp.flipud(g))
defvjp(anp.fliplr, lambda ans, x: lambda g: anp.fliplr(g))
defvjp(anp.rot90, lambda ans, x, k=1: lambda g: anp.rot90(g, -k))
defvjp(
    anp.trace,
    lambda ans, x, offset=0: lambda g: anp.einsum(
        "ij,...->ij...", anp.eye(x.shape[0], x.shape[1], k=offset), g
    ),
)
defvjp(anp.full, lambda ans, shape, fill_value, dtype=None: lambda g: anp.sum(g), argnums=(1,))
defvjp(anp.triu, lambda ans, x, k=0: lambda g: anp.triu(g, k=k))
defvjp(anp.tril, lambda ans, x, k=0: lambda g: anp.tril(g, k=k))
defvjp(anp.clip, lambda ans, x, a_min, a_max: lambda g: g * anp.logical_and(ans != a_min, ans != a_max))
defvjp(anp.swapaxes, lambda ans, x, axis1, axis2: lambda g: anp.swapaxes(g, axis2, axis1))
defvjp(anp.moveaxis, lambda ans, a, source, destination: lambda g: anp.moveaxis(g, destination, source))
defvjp(anp.real_if_close, lambda ans, x: lambda g: match_complex(x, g))
defvjp(anp.real, lambda ans, x: lambda g: match_complex(x, g))
defvjp(anp.imag, lambda ans, x: lambda g: match_complex(x, -1j * g))
defvjp(anp.conj, lambda ans, x: lambda g: anp.conj(g))
defvjp(anp.conjugate, lambda ans, x: lambda g: anp.conj(g))
defvjp(anp.angle, lambda ans, x: lambda g: match_complex(x, g * anp.conj(x * 1j) / anp.abs(x) ** 2))
defvjp(
    anp.where,
    None,
    lambda ans, c, x=None, y=None: lambda g: anp.where(c, g, anp.zeros(g.shape)),
    lambda ans, c, x=None, y=None: lambda g: anp.where(c, anp.zeros(g.shape), g),
)
defvjp(
    anp.cross,
    lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None: lambda g: anp.cross(
        b, g, axisb, axisc, axisa, axis
    ),
    lambda ans, a, b, axisa=-1, axisb=-1, axisc=-1, axis=None: lambda g: anp.cross(
        g, a, axisc, axisa, axisb, axis
    ),
)
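# linspace is linear in its endpoints: element i is start * (1 - t_i) + stop * t_i
# with t_i = i / (num - 1), so the VJPs contract g against those weights.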
defvjp(
    anp.linspace,
    lambda ans, start, stop, num: lambda g: anp.dot(anp.linspace(1.0, 0.0, num), g),
    lambda ans, start, stop, num: lambda g: anp.dot(anp.linspace(0.0, 1.0, num), g),
)

defvjp(
    anp._astype,
    lambda ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: lambda g: anp._astype(
        g, A.dtype
    ),
)


# ----- Trickier grads -----
def grad_rollaxis(ans, a, axis, start=0):
    if axis < 0:
        raise NotImplementedError(
            "Gradient of rollaxis not implemented for axis < 0. Please use moveaxis instead."
        )
    elif start < 0:
        raise NotImplementedError(
            "Gradient of rollaxis not implemented for start < 0. Please use moveaxis instead."
        )
    return lambda g: anp.rollaxis(g, start - 1, axis) if start > axis else anp.rollaxis(g, start, axis + 1)


defvjp(anp.rollaxis, grad_rollaxis)


def grad_diff(ans, a, n=1, axis=-1):
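    # The adjoint of a single np.diff along `axis` maps g to
    # concatenate((-g[:1], -diff(g), g[-1:])) along that axis; helper() applies this
    # adjoint n times to undo an n-th order diff.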
    nd = anp.ndim(a)
    ans_shape = anp.shape(ans)
    sl1 = [slice(None)] * nd
    sl1[axis] = slice(None, 1)

    sl2 = [slice(None)] * nd
    sl2[axis] = slice(-1, None)

    def undiff(g):
        if g.shape[axis] > 0:
            return anp.concatenate((-g[tuple(sl1)], -anp.diff(g, axis=axis), g[tuple(sl2)]), axis=axis)
        shape = list(ans_shape)
        shape[axis] = 1
        return anp.zeros(shape)

    def helper(g, n):
        if n == 0:
            return g
        return helper(undiff(g), n - 1)

    return lambda g: helper(g, n)


defvjp(anp.diff, grad_diff)


def grad_gradient(ans, x, *vargs, **kwargs):
    axis = kwargs.pop("axis", None)
    if vargs or kwargs:
        raise NotImplementedError("The only optional argument currently supported for np.gradient is axis.")
    if axis is None:
        axis = range(x.ndim)
    elif type(axis) is int:
        axis = [axis]
    else:
        axis = list(axis)

    x_dtype = x.dtype
    x_shape = x.shape
    nd = x.ndim

    def vjp(g):
        if anp.ndim(g) == nd:
            # add axis if gradient was along one axis only
            g = g[anp.newaxis]

        # accumulate gradient
        out = anp.zeros(x_shape, dtype=x_dtype)

        for i, a in enumerate(axis):
            # swap gradient axis to the front
            g_swap = anp.swapaxes(g[i], 0, a)[:, anp.newaxis]
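            # The concatenation below forms the adjoint of np.gradient's
            # finite-difference stencil: central differences in the interior and
            # one-sided differences at the two points on each edge.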

            out_axis = anp.concatenate(
                (
                    -g_swap[0] - 0.5 * g_swap[1],
                    g_swap[0] - 0.5 * g_swap[2],
                    (-1.0) * anp.gradient(g_swap, axis=0)[2:-2, 0],
                    0.5 * g_swap[-3] - g_swap[-1],
                    0.5 * g_swap[-2] + g_swap[-1],
                ),
                axis=0,
            )

            out = out + anp.swapaxes(out_axis, 0, a)

        return out

    return vjp


defvjp(anp.gradient, grad_gradient)


def grad_repeat(ans, x, repeats, axis=None):
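    # np.repeat copies each element `repeats` times, so the VJP sums every group of
    # `repeats` consecutive gradient slices back into one slice. Only scalar `repeats`
    # is handled here.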
    shape = anp.shape(x)

    def vjp(g):
        if axis is None:  # If axis is None, np.repeat() operates on the flattened array.
            expanded = anp.reshape(g, (anp.prod(shape),) + (repeats,))
            return anp.reshape(anp.sum(expanded, axis=1, keepdims=False), shape)
        else:
            if shape[axis] == 1:  # For this common case, the logic is simple.
                return anp.sum(g, axis=axis, keepdims=True)
            else:
                expanded = anp.reshape(g, shape[0 : axis + 1] + (repeats,) + shape[axis + 1 :])
                return anp.sum(expanded, axis=axis + 1, keepdims=False)

    return vjp


defvjp(anp.repeat, grad_repeat)


def grad_tile(ans, x, reps):
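    # np.tile lays down `rep` copies of x along each axis, so the VJP splits g into
    # those copies, sums them, and reshapes the result back to x's shape.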
    reps = [reps] if anp.isscalar(reps) else reps
    x_shape = anp.shape(x)

    def vjp(g):
        for axis, rep in enumerate(reps):
            g = sum(anp.split(g, rep, axis))
        return anp.reshape(g, x_shape)

    return vjp


defvjp(anp.tile, grad_tile)


def grad_kron(argnum, ans, orig_A, orig_B):
    # kron has different promotion rules than dot. the reshapes are necessary if
    # and only if (1) orig_B is 1D or (2) orig_A and/or orig_B are 0D
    orig_A_shape = anp.shape(orig_A)
    orig_B_shape = anp.shape(orig_B)

    def vjp(G):
        A, B = anp.atleast_2d(orig_A), anp.atleast_2d(orig_B)
        shape = list(A.shape + B.shape)
        n = anp.ndim(A)
        shape[n - 1], shape[n] = shape[n], shape[n - 1]
        reshaped_G = anp.swapaxes(anp.reshape(G, shape), n - 1, n)
        if argnum == 0:
            return match_complex(
                orig_A, anp.reshape(anp.tensordot(reshaped_G, B, axes=anp.ndim(B)), orig_A_shape)
            )
        else:
            return match_complex(
                orig_B, anp.reshape(anp.tensordot(A, reshaped_G, axes=anp.ndim(A)), orig_B_shape)
            )

    return vjp


defvjp(anp.kron, partial(grad_kron, 0), partial(grad_kron, 1))


def grad_transpose(ans, x, axes=None):
    if axes is not None:
        axes = anp.argsort(axes)
    return lambda g: anp.transpose(g, axes)


defvjp(anp.transpose, grad_transpose)


def repeat_to_match_shape(g, shape, dtype, axis, keepdims):
    """Returns the array g repeated along axis to fit vector space vs.
    Also returns the number of repetitions of the array."""
    if shape == ():
        return g, 1
    axis = list(axis) if isinstance(axis, tuple) else axis
    new_shape = onp.array(shape)
    new_shape[axis] = 1
    num_reps = onp.prod(onp.array(shape)[axis])
    # Can't use broadcast_to because of numpy bug: https://github.com/numpy/numpy/issues/9165
    # return anp.broadcast_to(anp.reshape(g, new_shape), shape), num_reps
    return anp.reshape(g, new_shape) + onp.zeros(shape, dtype=dtype), num_reps


def grad_broadcast_to(ans, x, new_shape):
    old_shape = anp.shape(x)
    assert anp.shape(ans) == new_shape
    assert len(old_shape) == len(new_shape), "Can't handle extra leading dims"
    broadcast_axes = tuple(
        onp.where(onp.logical_and(onp.array(old_shape) == 1, onp.array(new_shape) > 1))[0]
    )
    return lambda g: anp.sum(g, axis=broadcast_axes, keepdims=True)


defvjp(anp.broadcast_to, grad_broadcast_to)


def grad_np_sum(ans, x, axis=None, keepdims=False, dtype=None):
    shape, dtype = anp.shape(x), anp.result_type(x)
    return lambda g: repeat_to_match_shape(g, shape, dtype, axis, keepdims)[0]


defvjp(anp.sum, grad_np_sum)


def grad_np_mean(ans, x, axis=None, keepdims=False):
    shape, dtype = anp.shape(x), anp.result_type(x)

    def vjp(g):
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        return g_repeated / num_reps

    return vjp


defvjp(anp.mean, grad_np_mean)


def grad_np_prod(ans, x, axis=None, keepdims=False):  # TODO: Support tuples of axes.
    shape, dtype = anp.shape(x), anp.result_type(x)

    def vjp(g):
        g_repeated, _ = repeat_to_match_shape(g * ans, shape, dtype, axis, keepdims)
        return g_repeated / x

    return vjp


defvjp(anp.prod, grad_np_prod)


def grad_np_var(ans, x, axis=None, ddof=0, keepdims=False):
    shape, _, dtype, iscomplex = anp.metadata(x)
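    # d var(x) / d x_i = 2 * conj(x_i - mean(x)) / (N - ddof); the terms from
    # differentiating the mean cancel because the centered values sum to zero.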

    def vjp(g):
        if iscomplex:
            g = g + 0j
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
        return 2.0 * g_repeated * x_minus_mean / (num_reps - ddof)

    return vjp


defvjp(anp.var, grad_np_var)


def grad_np_std(ans, x, axis=None, ddof=0, keepdims=False):
    shape, _, dtype, iscomplex = anp.metadata(x)

    def vjp(g):
        if iscomplex:
            g = g + 0j
        g_repeated, num_reps = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        if num_reps <= 1:  # Avoid division by zero.
            return g_repeated * 0.0
        else:
            g_repeated, num_reps = repeat_to_match_shape(g / ans, shape, dtype, axis, keepdims)
            x_minus_mean = anp.conj(x - anp.mean(x, axis=axis, keepdims=True))
            return g_repeated * x_minus_mean / (num_reps - ddof)

    return vjp


defvjp(anp.std, grad_np_std)


def grad_chooser(ans, x, axis=None, keepdims=None):
    """Builds the gradient of functions that choose a single item, such as min or max."""
    shape, dtype = anp.shape(x), anp.result_type(x)

    def vjp(g):
        g_repeated, _ = repeat_to_match_shape(g, shape, dtype, axis, keepdims)
        argmax_locations = x == repeat_to_match_shape(ans, shape, dtype, axis, keepdims)[0]
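        # When several entries tie for the max/min, the gradient is split evenly among them.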
        return g_repeated * argmax_locations / onp.sum(argmax_locations, axis=axis, keepdims=True)

    return vjp


defvjp(anp.max, grad_chooser)
defvjp(anp.min, grad_chooser)
defvjp(anp.amax, grad_chooser)
defvjp(anp.amin, grad_chooser)


def reverse_axis(x, axis):
    x = x.swapaxes(axis, 0)
    x = x[::-1, ...]
    return x.swapaxes(0, axis)


def grad_np_cumsum(ans, x, axis=None):
    def vjp(g):
        if axis:
            return reverse_axis(anp.cumsum(reverse_axis(g, axis), axis), axis)
        else:
            return anp.reshape(anp.cumsum(g[::-1], axis)[::-1], x.shape)

    return vjp


defvjp(anp.cumsum, grad_np_cumsum)


def grad_inner(argnum, ans, A, B):
    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
    if A_ndim == 0 or B_ndim == 0:
        axes = ([], [])
    else:
        axes = ([A_ndim - 1], [B_ndim - 1])
    if argnum == 0:
        return lambda G: tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim)
    elif argnum == 1:
        return lambda G: tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim)


defvjp(anp.inner, partial(grad_inner, 0), partial(grad_inner, 1))


def matmul_adjoint_0(B, G, A_meta, B_ndim):
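    # anp.matmul broadcasts over leading (batch) dimensions and treats 1-D operands as
    # vectors, so this adjoint (and matmul_adjoint_1 below) inserts or removes the
    # implicit dimensions and then unbroadcasts back to the operand's recorded metadata.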
    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
        return unbroadcast(G * B, A_meta)
    _, A_ndim, _, _ = A_meta
    if A_ndim == 1:
        G = anp.expand_dims(G, anp.ndim(G) - 1)
    if B_ndim == 1:  # The result we need is an outer product
        B = anp.expand_dims(B, 0)
        G = anp.expand_dims(G, anp.ndim(G))
    else:  # We need to swap the last two axes of B
        B = anp.swapaxes(B, B_ndim - 2, B_ndim - 1)
    result = anp.matmul(G, B)
    return unbroadcast(result, A_meta)


def matmul_adjoint_1(A, G, A_ndim, B_meta):
    if anp.ndim(G) == 0:  # A_ndim == B_ndim == 1
        return unbroadcast(G * A, B_meta)
    _, B_ndim, _, _ = B_meta
    B_is_vec = B_ndim == 1
    if B_is_vec:
        G = anp.expand_dims(G, anp.ndim(G))
    if A_ndim == 1:  # The result we need is an outer product
        A = anp.expand_dims(A, 1)
        G = anp.expand_dims(G, anp.ndim(G) - 1)
    else:  # We need to swap the last two axes of A
        A = anp.swapaxes(A, A_ndim - 2, A_ndim - 1)
    result = anp.matmul(A, G)
    if B_is_vec:
        result = anp.squeeze(result, anp.ndim(G) - 1)
    return unbroadcast(result, B_meta)


def matmul_vjp_0(ans, A, B):
    A_meta = anp.metadata(A)
    B_ndim = anp.ndim(B)
    return lambda g: matmul_adjoint_0(B, g, A_meta, B_ndim)


def matmul_vjp_1(ans, A, B):
    A_ndim = anp.ndim(A)
    B_meta = anp.metadata(B)
    return lambda g: matmul_adjoint_1(A, g, A_ndim, B_meta)


defvjp(anp.matmul, matmul_vjp_0, matmul_vjp_1)


@primitive
def dot_adjoint_0(B, G, A_meta, B_meta):
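    # dot_adjoint_0 and dot_adjoint_1 are primitives in their own right, with VJPs
    # registered below, so that second and higher derivatives of anp.dot can be taken.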
    _, A_ndim, A_dtype, _ = A_meta
    _, B_ndim, _, _ = B_meta
    if B_ndim == 0 or B_ndim == 1 or A_ndim == 0:
        contract_num = max(0, B_ndim - (A_ndim != 0))
        out = onp.tensordot(G, B, contract_num)
    else:
        out = onp.tensordot(G, onp.swapaxes(B, -1, -2), B_ndim - 1)
    return onp.asarray(out, dtype=A_dtype)


@primitive
def dot_adjoint_1(A, G, A_meta, B_meta):
    _, A_ndim, _, _ = A_meta
    _, B_ndim, B_dtype, _ = B_meta
    needs_transpose = B_ndim > 1 and A_ndim != 0
    swap = (lambda x: onp.swapaxes(x, -1, -2)) if needs_transpose else (lambda x: x)
    if A_ndim == 0 or A_ndim == 1 or B_ndim == 0:
        contract_num = max(0, A_ndim - (B_ndim != 0))
        out = swap(onp.tensordot(G, A, contract_num))
    else:
        out = swap(onp.tensordot(G, A, [range(-A_ndim - B_ndim + 2, -B_ndim + 1), range(A_ndim - 1)]))
    return onp.asarray(out, dtype=B_dtype)


def dot_vjp_0(ans, A, B):
    A_meta, B_meta = anp.metadata(A), anp.metadata(B)
    return lambda g: match_complex(A, dot_adjoint_0(B, g, A_meta, B_meta))


def dot_vjp_1(ans, A, B):
    A_meta, B_meta = anp.metadata(A), anp.metadata(B)
    return lambda g: match_complex(B, dot_adjoint_1(A, g, A_meta, B_meta))


defvjp(anp.dot, dot_vjp_0, dot_vjp_1)

defvjp(
    dot_adjoint_0,
    lambda ans, B, g, An, Bn: lambda A: match_complex(B, dot_adjoint_1(A, g, An, Bn)),
    lambda ans, B, g, An, Bn: lambda A: match_complex(g, anp.dot(A, B)),
)

defvjp(
    dot_adjoint_1,
    lambda ans, A, g, An, Bn: lambda B: match_complex(A, dot_adjoint_0(B, g, An, Bn)),
    lambda ans, A, g, An, Bn: lambda B: match_complex(g, anp.dot(A, B)),
)


@primitive
def tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # A |--> np.tensordot(A, B, axes)
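    # Defined as a primitive with its own VJP (registered below) so that higher-order
    # derivatives work. `axes` may be an int or a pair of axis lists; negative axis
    # entries are normalized modulo each operand's ndim.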
    if B_ndim == 0:
        return G * B

    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        B_axes = onp.arange(B_ndim)
        return onp.tensordot(G, B, [G_axes[A_ndim - axes :], B_axes[axes:]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [
            onp.asarray(axes[0], dtype="int64") % A_ndim,
            onp.asarray(axes[1], dtype="int64") % B_ndim,
        ]
        other_axes = [onp.delete(A_axes, summed_axes[0]), onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(G, B, [G_axes[len(other_axes[0]) :], other_axes[1]])
        perm = onp.argsort(onp.concatenate((other_axes[0], summed_axes[0][onp.argsort(summed_axes[1])])))
        return onp.transpose(out, perm)


@primitive
def tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim):
    # The adjoint of the operator
    # B |--> np.tensordot(A, B, axes)
    if A_ndim == 0:
        return G * A

    G_axes = onp.arange(onp.ndim(G))
    if type(axes) is int:
        axes = max(axes, 0)
        A_axes = onp.arange(A_ndim)
        return onp.tensordot(A, G, [A_axes[: A_ndim - axes], G_axes[: A_ndim - axes]])
    else:
        axes0 = [axes[0]] if type(axes[0]) is int else axes[0]
        axes1 = [axes[1]] if type(axes[1]) is int else axes[1]
        axes = [axes0, axes1]
        A_axes = onp.arange(A_ndim)
        B_axes = onp.arange(B_ndim)
        summed_axes = [
            onp.asarray(axes[0], dtype="int64") % A_ndim,
            onp.asarray(axes[1], dtype="int64") % B_ndim,
        ]
        other_axes = [onp.delete(A_axes, summed_axes[0]), onp.delete(B_axes, summed_axes[1])]
        out = onp.tensordot(A, G, [other_axes[0], G_axes[: len(other_axes[0])]])
        perm = onp.argsort(onp.concatenate((summed_axes[1][onp.argsort(summed_axes[0])], other_axes[1])))
        return onp.transpose(out, perm)


def tensordot_vjp_0(ans, A, B, axes=2):
    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
    return lambda G: match_complex(A, tensordot_adjoint_0(B, G, axes, A_ndim, B_ndim))


def tensordot_vjp_1(ans, A, B, axes=2):
    A_ndim, B_ndim = anp.ndim(A), anp.ndim(B)
    return lambda G: match_complex(B, tensordot_adjoint_1(A, G, axes, A_ndim, B_ndim))


defvjp(anp.tensordot, tensordot_vjp_0, tensordot_vjp_1)
defvjp(
    tensordot_adjoint_0,
    lambda ans, B, G, axes, An, Bn: lambda A: match_complex(B, tensordot_adjoint_1(A, G, axes, An, Bn)),
    lambda ans, B, G, axes, An, Bn: lambda A: match_complex(G, anp.tensordot(A, B, axes)),
)
defvjp(
    tensordot_adjoint_1,
    lambda ans, A, G, axes, An, Bn: lambda B: match_complex(A, tensordot_adjoint_0(B, G, axes, An, Bn)),
    lambda ans, A, G, axes, An, Bn: lambda B: match_complex(G, anp.tensordot(A, B, axes)),
)
defvjp(
    anp.outer,
    lambda ans, a, b: lambda g: match_complex(a, anp.dot(g, b.T)),
    lambda ans, a, b: lambda g: match_complex(b, anp.dot(a.T, g)),
)


def grad_concatenate_args(argnum, ans, axis_args, kwargs):
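    # axis_args packs the concatenation axis as element 0 and the arrays after it. The
    # VJP for argument argnum slices out the block of g that came from that array; its
    # offset is the summed extent of the preceding arrays along the concatenation axis.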
    axis, args = axis_args[0], axis_args[1:]
    sizes = [anp.shape(a)[axis] for a in args[:argnum]]
    start = sum(sizes[:-1])
    idxs = [slice(None)] * ans.ndim
    idxs[axis] = slice(start, start + sizes[-1])
    return lambda g: g[tuple(idxs)]


defvjp_argnum(anp.concatenate_args, grad_concatenate_args)


def wrapped_reshape(x, *args, **kwargs):
    # The reshape method can be called like A.reshape((5, 4)) or A.reshape(5, 4).
    # The anp.reshape function only accepts the tuple form, so we wrap it to handle both.
    if isinstance(args[0], int):
        return anp.reshape(x, args, **kwargs)
    else:
        return anp.reshape(x, *args, **kwargs)


setattr(ArrayBox, "reshape", wrapped_reshape)


def grad_sort(ans, x, axis=-1, kind="quicksort", order=None):
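    # Sorting only permutes elements, so the VJP applies the inverse of the sorting
    # permutation (via unpermuter below) to the incoming gradient.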
    # TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError("Gradient of sort not implemented for multi-dimensional arrays.")
    sort_perm = anp.argsort(x, axis, kind, order)
    return lambda g: unpermuter(g, sort_perm)


defvjp(anp.sort, grad_sort)
if onp.lib.NumpyVersion(onp.__version__) < "2.0.0":
    defvjp(anp.msort, grad_sort)  # Until multi-D is allowed, these are the same.


def grad_partition(ans, x, kth, axis=-1, kind="introselect", order=None):
    # TODO: Cast input with np.asanyarray()
    if len(x.shape) > 1:
        raise NotImplementedError("Gradient of partition not implemented for multi-dimensional arrays.")
    partition_perm = anp.argpartition(x, kth, axis, kind, order)
    return lambda g: unpermuter(g, partition_perm)


defvjp(anp.partition, grad_partition)


def unpermuter(g, permutation):
    unsort = anp.zeros(len(permutation), dtype=int)
    unsort[permutation] = list(range(len(permutation)))
    return g[unsort]


def grad_reshape_list(ans, *arys):
    if len(arys) > 1:
        raise NotImplementedError("Can't handle multiple arguments yet.")
    return lambda g: anp.reshape(g, anp.shape(arys[0]))


defvjp(anp.atleast_1d, grad_reshape_list)
defvjp(anp.atleast_2d, grad_reshape_list)
defvjp(anp.atleast_3d, grad_reshape_list)


def grad_einsum(argnum, ans, operands_, kwargs):
    result_meta = anp.metadata(operands_[argnum])

    def vjp(g):
        operands = operands_
        if isinstance(operands[0], str):  # using "ijk" convention.
            in_subs, out_subs, _ = anp.parse_einsum_input(*operands)
            string, operands = operands[0], operands[1:]
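            # operands[0] was the subscript string, so the operand being differentiated
            # sits at index argnum - 1 of the remaining operands. The VJP einsums g
            # against the other operands, with the roles of that operand's subscripts
            # and the output subscripts swapped.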

            in_subs_list = in_subs.split(",")
            op_num = argnum - 1
            subs_wrt = in_subs_list[op_num]
            rest_of_ops = operands[:op_num] + operands[op_num + 1 :]
            rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num + 1 :]

            # subscripts that only appear in subs_wrt (and not in other subscript lists
            # or in the output) are implicitly being summed out, as if contracted
            # against a tensor of ones. we make that tensor of ones explicit to handle
            # the necessary vjp broadcasting inside einsum.
            other_named_subs = set("".join([out_subs] + rest_of_subs))
            naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt) if sub not in other_named_subs]
            if naked_summed:
                naked_summed_dims, ones_subs = zip(*naked_summed)
                ones_subs = "".join(ones_subs)
                ones = onp.ones(onp.array(operands[op_num].shape)[list(naked_summed_dims)])
                new_input_subs = ",".join([out_subs, ones_subs] + rest_of_subs)
                new_operands = (g, ones) + rest_of_ops
            else:
                new_input_subs = ",".join([out_subs] + rest_of_subs)
                new_operands = (g,) + rest_of_ops

            new_subscripts = new_input_subs + "->" + subs_wrt
            return unbroadcast(anp.einsum(new_subscripts, *new_operands), result_meta)
        else:  # using (op0, sublist0, op1, sublist1, ..., sublistout) convention
            if len(operands) % 2 == 0:
                raise NotImplementedError("Need sublistout argument")
            operands = list(operands)
            rest_of_ops = (
                [operands[-1]] + operands[:argnum] + operands[(argnum + 2) : -1] + [operands[argnum + 1]]
            )
            return unbroadcast_einsum(anp.einsum(g, *rest_of_ops), result_meta, operands[argnum + 1])

    return vjp


defvjp_argnum(anp.einsum, grad_einsum)

defvjp(
    anp.diagonal,
    lambda ans, A, offset=0, axis1=0, axis2=1: lambda g: anp.make_diagonal(g, offset, axis1, axis2),
)
defvjp(
    anp.make_diagonal,
    lambda ans, D, offset=0, axis1=0, axis2=1: lambda g: anp.diagonal(g, offset, axis1, axis2),
)


def match_complex(target, x):
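    # Casts the gradient to the complex-ness of `target`: real targets get the real
    # part of a complex gradient, complex targets get a complex-typed gradient.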
    target_iscomplex = anp.iscomplexobj(target)
    x_iscomplex = anp.iscomplexobj(x)
    if x_iscomplex and not target_iscomplex:
        return anp.real(x)
    elif not x_iscomplex and target_iscomplex:
        return x + 0j
    else:
        return x


def unbroadcast(x, target_meta, broadcast_idx=0):
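    # Sums x over the dimensions that broadcasting added (extra leading dims and
    # size-1 dims of the target) so that it matches the target's original shape.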
    target_shape, target_ndim, dtype, target_iscomplex = target_meta
    while anp.ndim(x) > target_ndim:
        x = anp.sum(x, axis=broadcast_idx)
    for axis, size in enumerate(target_shape):
        if size == 1:
            x = anp.sum(x, axis=axis, keepdims=True)
    if anp.iscomplexobj(x) and not target_iscomplex:
        x = anp.real(x)
    return x


def unbroadcast_f(target, f):
    target_meta = anp.metadata(target)
    return lambda g: unbroadcast(f(g), target_meta)


def unbroadcast_einsum(x, target_meta, subscript):
    if Ellipsis not in subscript:
        return x
    elif subscript[0] == Ellipsis:
        return unbroadcast(x, target_meta, 0)
    elif subscript[-1] == Ellipsis:
        return unbroadcast(x, target_meta, -1)
    else:
        return unbroadcast(x, target_meta, subscript.index(Ellipsis))


def balanced_eq(x, z, y):
    return (x == z) / (1.0 + (x == y))


def replace_zero(x, val):
    return anp.where(x, x, val)


# ----- extra functions used internally  -----


def array_from_args_gradmaker(argnum, ans, args, kwargs):
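    # anp.array_from_args receives two leading arguments (positional and keyword
    # metadata) before the array elements, so operand argnum of the primitive
    # corresponds to element argnum - 2 of the stacked output.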
    return lambda g: g[argnum - 2]


defvjp_argnum(anp.array_from_args, array_from_args_gradmaker)


def array_from_scalar_or_array_gradmaker(ans, array_args, array_kwargs, scarray):
    ndmin = array_kwargs.get("ndmin", 0)
    scarray_ndim = anp.ndim(scarray)
    if ndmin > scarray_ndim:
        return lambda g: anp.squeeze(g, axis=tuple(range(ndmin - scarray_ndim)))
    else:
        return lambda g: g


defvjp(anp._array_from_scalar_or_array, array_from_scalar_or_array_gradmaker, argnums=(2, 3))


@primitive
def untake(x, idx, vs):
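    # The VJP of indexing: scatters x (the incoming gradient) back into an array of the
    # original vector space at `idx`. Returning a SparseObject defers the scatter
    # (onp.add.at) until gradients are accumulated, avoiding a dense intermediate per
    # indexing operation.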
    if isinstance(idx, list) and (len(idx) == 0 or not isinstance(idx[0], slice)):
        idx = onp.array(idx, dtype="int64")

    def mut_add(A):
        onp.add.at(A, idx, x)
        return A

    return SparseObject(vs, mut_add)


defvjp(func(ArrayBox.__getitem__), lambda ans, A, idx: lambda g: untake(g, idx, vspace(A)))
defvjp(untake, lambda ans, x, idx, _: lambda g: g[idx])


def _unpad(array, width):
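    # Normalizes pad_width (scalar, single pair, or per-axis pairs) to per-axis pairs,
    # then slices the padded border off `array`. Used by the VJP of constant-mode np.pad.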
    if anp.isscalar(width):
        width = [[width, width]]
    elif anp.shape(width) == (1,):
        width = [anp.concatenate((width, width))]
    elif anp.shape(width) == (2,):
        width = [width]
    if anp.shape(width)[0] == 1:
        width = anp.repeat(width, anp.ndim(array), 0)
    idxs = tuple(slice(l, -u or None) for l, u in width)
    return array[idxs]


def pad_vjp(ans, array, pad_width, mode, **kwargs):
    assert mode == "constant", "Only constant mode padding is supported."
    return lambda g: _unpad(g, pad_width)


defvjp(anp.pad, pad_vjp)
